{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Query Data using LLM\n",
    "\n",
    "Here is the overall RAG pipeline.   In this notebook, we will do steps (6), (7), (8), (9) and (10)\n",
    "- Importing data is already done in this notebook [rag_2_load_data_into_milvus.ipynb](rag_2_load_data_into_milvus.ipynb)\n",
    "- 👉 Step 6: Calculate embedding for user query\n",
    "- 👉 Step 7 & 8: Send the query to vector db to retrieve relevant documents\n",
    "- 👉 Step 9 & 10: Send the query and relevant documents (returned above step) to LLM and get answers to our query\n",
    "\n",
    "![image missing](media/rag-overview-2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-1: Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from my_config import MY_CONFIG"
   ]
  },
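  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, verify the settings this notebook relies on. This is a minimal sanity check, assuming `my_config.py` defines the attributes `DB_URI`, `COLLECTION_NAME`, `EMBEDDING_MODEL`, `LLM_MODEL` and `MAX_CONTEXT_WINDOW` (all used in the steps below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: print the config values used in the steps below.\n",
    "# These attribute names are assumed to be defined in my_config.py\n",
    "for attr in ['DB_URI', 'COLLECTION_NAME', 'EMBEDDING_MODEL', 'LLM_MODEL', 'MAX_CONTEXT_WINDOW']:\n",
    "    print (attr, '=', getattr(MY_CONFIG, attr, '<not set>'))"
   ]
  },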
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-2: Load .env file\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "## Load Settings from .env file\n",
    "from dotenv import find_dotenv, dotenv_values\n",
    "\n",
    "# _ = load_dotenv(find_dotenv()) # read local .env file\n",
    "config = dotenv_values(find_dotenv())\n",
    "\n",
    "# debug\n",
    "# print (config)\n",
    "\n",
    "MY_CONFIG.REPLICATE_API_TOKEN = config.get('REPLICATE_API_TOKEN')\n",
    "\n",
    "if  MY_CONFIG.REPLICATE_API_TOKEN:\n",
    "    print (\"✅ config REPLICATE_API_TOKEN found\")\n",
    "else:\n",
    "    raise Exception (\"'❌ REPLICATE_API_TOKEN' is not set.  Please set it above to continue...\")\n"
   ]
  },
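  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the `.env` file is expected to sit next to this notebook (or in a parent directory, since `find_dotenv` searches upward) and contain at least a line like the one below. The value shown is a placeholder, not a real token; you can create one at https://replicate.com/\n",
    "\n",
    "```text\n",
    "REPLICATE_API_TOKEN=<your-replicate-api-token>\n",
    "```"
   ]
  },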
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-3: Connect to Vector Database\n",
    "\n",
    "Milvus can be embedded and easy to use.\n",
    "\n",
    "<span style=\"color:blue;\">Note: If you encounter an error about unable to load database, try this: </span>\n",
    "\n",
    "- <span style=\"color:blue;\">In **vscode** : **restart the kernel** of previous notebook. This will release the db.lock </span>\n",
    "- <span style=\"color:blue;\">In **Jupyter**: Do `File --> Close and Shutdown Notebook` of previous notebook. This will release the db.lock</span>\n",
    "- <span style=\"color:blue;\">Re-run this cell again</span>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import MilvusClient\n",
    "\n",
    "milvus_client = MilvusClient(MY_CONFIG.DB_URI)\n",
    "\n",
    "print (\"✅ Connected to Milvus instance:\", MY_CONFIG.DB_URI)"
   ]
  },
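  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, optional check, confirm that the collection created by the previous notebook exists before querying it. `has_collection` is part of the pymilvus `MilvusClient` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: the collection should have been created by\n",
    "# rag_2_load_data_into_milvus.ipynb\n",
    "if milvus_client.has_collection(MY_CONFIG.COLLECTION_NAME):\n",
    "    print (\"✅ Collection found:\", MY_CONFIG.COLLECTION_NAME)\n",
    "else:\n",
    "    print (\"❌ Collection not found. Please run rag_2_load_data_into_milvus.ipynb first.\")"
   ]
  },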
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-4: Setup Embeddings\n",
    "\n",
    "Use the same embeddings we used to index our documents!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "model = SentenceTransformer(MY_CONFIG.EMBEDDING_MODEL)\n",
    "\n",
    "def get_embeddings (str):\n",
    "    embeddings = model.encode(str, normalize_embeddings=True)\n",
    "    return embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test embeddings\n",
    "embeddings = get_embeddings('Paris 2024 Olympics')\n",
    "print ('embeddings len =', len(embeddings))\n",
    "print ('embeddings[:5] = ', embeddings[:5])"
   ]
  },
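  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short aside on why the vector search below uses the inner product (`IP`) metric: since we encode with `normalize_embeddings=True`, every vector has unit length, so inner product equals cosine similarity. The quick check below is just an illustration; the second query string is a made-up example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# With normalized embeddings, inner product == cosine similarity\n",
    "v1 = get_embeddings('Paris 2024 Olympics')\n",
    "v2 = get_embeddings('Summer games in France')\n",
    "print ('norm of v1 =', np.linalg.norm(v1))  # should be ~1.0\n",
    "print ('inner product (= cosine similarity here) =', np.dot(v1, v2))"
   ]
  },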
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-5: Vector Search and RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get relevant documents using vector / sementic search\n",
    "\n",
    "def fetch_relevant_documents (query : str) :\n",
    "    search_res = milvus_client.search(\n",
    "        collection_name=MY_CONFIG.COLLECTION_NAME,\n",
    "        data = [get_embeddings(query)], # Use the `emb_text` function to convert the question to an embedding vector\n",
    "        limit=3,  # Return top 3 results\n",
    "        search_params={\"metric_type\": \"IP\", \"params\": {}},  # Inner product distance\n",
    "        output_fields=[\"text\"],  # Return the text field\n",
    "    )\n",
    "    # print (search_res)\n",
    "\n",
    "    retrieved_docs_with_distances = [\n",
    "        {'text': res[\"entity\"][\"text\"], 'distance' : res[\"distance\"]} for res in search_res[0]\n",
    "    ]\n",
    "    return retrieved_docs_with_distances\n",
    "## --- end ---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test relevant vector search\n",
    "import json\n",
    "import pprint\n",
    "\n",
    "question = \"What was the training data used to train Granite models?\"\n",
    "relevant_docs = fetch_relevant_documents(question)\n",
    "pprint.pprint(relevant_docs, indent=4)"
   ]
  },
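  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One optional refinement: with the `IP` metric and normalized embeddings, the returned `distance` is a similarity score (higher = more similar), so weakly related chunks can be dropped with a simple threshold. The cutoff below is an arbitrary illustration; tune it for your data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally drop weakly related matches; 0.5 is an arbitrary example cutoff\n",
    "MIN_SCORE = 0.5\n",
    "\n",
    "filtered_docs = [doc for doc in relevant_docs if doc['distance'] >= MIN_SCORE]\n",
    "print (f\"kept {len(filtered_docs)} of {len(relevant_docs)} documents\")"
   ]
  },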
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-6: Initialize LLM\n",
    "\n",
    "### LLM Choices at Replicate\n",
    "\n",
    "\n",
    "| Model                               | Publisher | Params | Description                                          |\n",
    "|-------------------------------------|-----------|--------|------------------------------------------------------|\n",
    "| ibm-granite/granite-3.0-8b-instruct | IBM       | 8 B    | IBM's newest Granite Model v3.0  (default)           |\n",
    "| ibm-granite/granite-3.0-2b-instruct | IBM       | 2 B    | IBM's newest Granite Model v3.0                      |\n",
    "| meta/meta-llama-3.1-405b-instruct   | Meta      | 405 B  | Meta's flagship 405 billion parameter language model |\n",
    "| meta/meta-llama-3-8b-instruct       | Meta      | 8 B    | Meta's 8 billion parameter language model            |\n",
    "| meta/meta-llama-3-70b-instruct      | Meta      | 70 B   | Meta's 70 billion parameter language model           |\n",
    "\n",
    "References \n",
    "\n",
    "- https://www.ibm.com/granite\n",
    "- https://www.llama.com/\n",
    "- https://replicate.com/  "
   ]
  },
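  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To experiment with a different model from the table above, you can override `MY_CONFIG.LLM_MODEL` before initializing the LLM, as sketched below. Note that the `prompt_template` used later in this notebook is Llama-3 style, so other model families may expect a different template."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To try a different model from the table above, override the default\n",
    "# before initializing the LLM, e.g.:\n",
    "# MY_CONFIG.LLM_MODEL = 'ibm-granite/granite-3.0-2b-instruct'"
   ]
  },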
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"REPLICATE_API_TOKEN\"] = MY_CONFIG.REPLICATE_API_TOKEN\n",
    "\n",
    "print ('Using model:', MY_CONFIG.LLM_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import replicate\n",
    "\n",
    "def ask_LLM (question, relevant_docs):\n",
    "    context = \"\\n\".join(\n",
    "        [doc['text'] for doc in relevant_docs]\n",
    "    )\n",
    "    \n",
    "    max_new_tokens = 1024\n",
    "    \n",
    "    ## Truncate context, so we don't over shoot context window\n",
    "    context = context[:(MY_CONFIG.MAX_CONTEXT_WINDOW - max_new_tokens - 100)]\n",
    "    # print (\"context length:\", len(context))\n",
    "    # print ('============ context (this is the context supplied to LLM) ============')\n",
    "    # print (context)\n",
    "    # print ('============ end  context ============', flush=True)\n",
    "\n",
    "    system_prompt = \"\"\"\n",
    "    Human: You are an AI assistant. You are able to find answers to the questions from the contextual passage snippets provided.\n",
    "    \"\"\"\n",
    "    user_prompt = f\"\"\"\n",
    "    Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.\n",
    "    <context>\n",
    "    {context}\n",
    "    </context>\n",
    "    <question>\n",
    "    {question}\n",
    "    </question>\n",
    "    \"\"\"\n",
    "    # print (\"user_prompt length:\", len(user_prompt))\n",
    "\n",
    "    print ('============ here is the answer from LLM =====')\n",
    "    # The meta/meta-llama-3-8b-instruct model can stream output as it's running.\n",
    "    for event in replicate.stream(\n",
    "        MY_CONFIG.LLM_MODEL,\n",
    "        input={\n",
    "            \"top_k\": 1,\n",
    "            \"top_p\": 0.95,\n",
    "            \"prompt\": user_prompt,\n",
    "            #\"max_tokens\": MY_CONFIG.MAX_CONTEXT_WINDOW,\n",
    "            \"temperature\": 0.1,\n",
    "            \"system_prompt\": system_prompt,\n",
    "            \"length_penalty\": 1,\n",
    "            \"max_new_tokens\": max_new_tokens,\n",
    "            \"stop_sequences\": \"<|end_of_text|>,<|eot_id|>\",\n",
    "            \"prompt_template\": \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n",
    "            \"presence_penalty\": 0,\n",
    "            \"log_performance_metrics\": False\n",
    "        },\n",
    "    ):\n",
    "        print(str(event), end=\"\")\n",
    "    ## ---\n",
    "    print ('\\n======  end LLM answer ======\\n', flush=True)\n"
   ]
  },
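  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you would rather get the whole answer back as a single string (for logging or testing) instead of streaming it, `replicate.run` collects the model output. The variant below is a minimal sketch reusing the same context-building logic as `ask_LLM`; for language models, `replicate.run` typically returns an iterable of text chunks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-streaming variant: returns the full answer as one string\n",
    "def ask_LLM_non_streaming (question, relevant_docs):\n",
    "    context = \"\\n\".join([doc['text'] for doc in relevant_docs])\n",
    "    max_new_tokens = 1024\n",
    "    context = context[:(MY_CONFIG.MAX_CONTEXT_WINDOW - max_new_tokens - 100)]\n",
    "\n",
    "    user_prompt = f\"\"\"\n",
    "    Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.\n",
    "    <context>\n",
    "    {context}\n",
    "    </context>\n",
    "    <question>\n",
    "    {question}\n",
    "    </question>\n",
    "    \"\"\"\n",
    "\n",
    "    output = replicate.run(\n",
    "        MY_CONFIG.LLM_MODEL,\n",
    "        input={\n",
    "            \"prompt\": user_prompt,\n",
    "            \"temperature\": 0.1,\n",
    "            \"max_new_tokens\": max_new_tokens,\n",
    "        },\n",
    "    )\n",
    "    # replicate.run typically yields text chunks for LLMs; join into one string\n",
    "    return \"\".join(output)"
   ]
  },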
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-7: Query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "question = \"What was the training data used to train Granite models?\"\n",
    "relevant_docs = fetch_relevant_documents(question)\n",
    "ask_LLM(question=question, relevant_docs=relevant_docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "question = \"What is attention mechanism?\"\n",
    "relevant_docs = fetch_relevant_documents(question)\n",
    "ask_LLM(question=question, relevant_docs=relevant_docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "question = \"When was the moon landing?\"\n",
    "relevant_docs = fetch_relevant_documents(question)\n",
    "ask_LLM(question=question, relevant_docs=relevant_docs)"
   ]
  },
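  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a small convenience wrapper that ties steps 6 through 10 together: embed the question, retrieve relevant chunks from Milvus, and ask the LLM. The sample question is just an illustration; try your own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end helper: retrieve relevant docs, then ask the LLM\n",
    "def run_rag_query (question):\n",
    "    relevant_docs = fetch_relevant_documents(question)\n",
    "    ask_LLM(question=question, relevant_docs=relevant_docs)\n",
    "\n",
    "# Try your own question, e.g.:\n",
    "run_rag_query(\"What is a transformer model?\")"
   ]
  }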
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
